{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp callback.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Callbacks\n",
    "\n",
    "> Callbacks which work with a learner's data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.basics import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastai.test_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class CollectDataCallback(Callback):\n",
    "    \"Record each batch as a detached `(xb, yb, pred, loss)` tuple in `self.data`. Mainly for testing\"\n",
    "    def before_fit(self):\n",
    "        # Reset the collected data to an empty `L`\n",
    "        self.data = L()\n",
    "\n",
    "    def after_batch(self):\n",
    "        # Pass the batch tuple through `learn.to_detach` before storing it\n",
    "        batch_record = (self.xb, self.yb, self.pred, self.loss)\n",
    "        self.data.append(self.learn.to_detach(batch_record))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates()\n",
    "class WeightedDL(TfmdDL):\n",
    "    \"`TfmdDL` that, when shuffling, samples indices with probabilities proportional to `wgts`\"\n",
    "    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # No weights given: fall back to an equal weight for every item\n",
    "        if wgts is None: wgts = [1.]*len(dataset)\n",
    "        wgts = array(wgts)\n",
    "        # Normalise so the weights sum to 1 before they are used as `p` in `get_idxs`\n",
    "        self.wgts = wgts/wgts.sum()\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.n==0: return []\n",
    "        if self.shuffle:\n",
    "            # Draw `n` indices according to the normalised weights (numpy's default is with replacement)\n",
    "            return list(np.random.choice(self.n, self.n, p=self.wgts))\n",
    "        return super().get_idxs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):\n",
    "    \"Create dataloaders of type `WeightedDL`, passing `wgts` in the first `dl_kwargs` entry\"\n",
    "    # One kwargs dict per subset: `wgts` goes to the first one (presumably the training set -\n",
    "    # confirm against `dataloaders`), the remaining subsets get no extra kwargs\n",
    "    per_subset_kwargs = ({'wgts':wgts},) + ({},)*(self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=per_subset_kwargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy data: the floats 0..n-1, with `wgts=range(n)` so item i is given weight i\n",
    "n = 160\n",
    "dsets = Datasets(torch.arange(n).float())\n",
    "dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)\n",
    "# `CollectDataCallback` stores every batch so the sampled inputs can be inspected afterwards\n",
    "learn = synth_learner(data=dls, cbs=CollectDataCallback)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, nan, None, '00:01']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAANcklEQVR4nO3db4xld13H8fdHtlShGLbutK5tcWpTiDxhS9YKVkmhgP1DKDwwoRGyRszyAAwo/tlCovisRSg+MehiK42UmqYUaCgKtRIJiSlOaylbl1r+LLBl6U5DIqgPpPD1wT0Lk2Fm5+7cO3Pvt7xfyeSe8ztn5nwyO/PZc3/3nDupKiRJ/fzErANIkjbHApekpixwSWrKApekpixwSWpqx3YebNeuXbW4uLidh5Sk9u67777Hq2ph9fi2Fvji4iJLS0vbeUhJai/JV9cadwpFkpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpra1jsxJc2PxQN3zezYR667ambHfjLxDFySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJakpC1ySmrLAJampDQs8yXlJPpXkcJKHkrx5GH9HkkeTPDB8XLn1cSVJJ4zzNzGfAN5aVfcneQZwX5K7h23vqap3bV08SdJ6NizwqjoGHBuWv5PkMHDOVgeTJJ3cKc2BJ1kELgLuHYbelOTBJDcl2bnO5+xPspRkaXl5ebK0kqQfGLvAk5wBfAh4S1V9G3gvcAGwh9EZ+rvX+ryqOlhVe6tq78LCwuSJJUnAmAWe5DRG5X1LVd0BUFWPVdX3qur7wPuAi7cupiRptXGuQglwI3C4qm5YMb57xW6vBg5NP54kaT3jXIVyCfA64PNJHhjG3gZck2QPUMAR4A1bkE+StI5xrkL5DJA1Nn18+nEkSePyTkxJamqcKRRJW2jxwF2zjvBjY5bf6yPXXTX1r+kZuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMbFniS85J8KsnhJA8lefMwfmaSu5M8Mjzu3Pq4kqQTxjkDfwJ4a1X9IvAC4I1JngscAO6pqguBe4Z1SdI22bDAq+pYVd0/LH8HOAycA1wN3DzsdjPwqi3KKElawynNgSdZBC4C7gXOrqpjMCp54Kx1Pmd/kqUkS8vLyxPGlSSdMHaBJzkD+BDwlqr69rifV1UHq2pvVe1dWFjYTEZJ0hrGKvAkpzEq71uq6o5h+LEku4ftu4HjWxNRkrSWca5CCXAjcLiqblix6U5g37C8D/jo9ONJktazY4x9LgFeB3w+yQPD2NuA64Dbkrwe+BrwG1uSUJK0pg0LvKo+A2SdzZdNN44kaVzeiSlJTY0zhSJJU7V44K5ZR3hS8AxckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKd9OVsK3N1VPnoFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMWuCQ1ZYFLUlMbFniSm5IcT3Joxdg7kjya5IHh48qtjSlJWm2cM/D3A5evMf6eqtozfHx8urEkSRvZsMCr6tPAt7YhiyTpFEwyB/6mJA8OUyw719spyf4kS0mWlpeXJzicJGmlzRb4e4ELgD3AMeDd6+1YVQeram9V7V1YWNjk4SRJq22qwKvqsar6XlV9H3gfcPF0Y0mSNrKpAk+ye8Xqq4FD6+0rSdoaG/5JtSS3ApcCu5IcBf4UuDTJHqCAI8Abti6iJGktGxZ4VV2zxvCN
W5BFknQKvBNTkpryr9LrR8zyL7Qfue6qmR1b6sYzcElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKYscElqygKXpKZ8O1nNlVm+la3UjWfgktSUBS5JTVngktSUBS5JTVngktSUBS5JTXkZ4RzzkjpJJ+MZuCQ1ZYFLUlMWuCQ1ZYFLUlMbFniSm5IcT3JoxdiZSe5O8sjwuHNrY0qSVhvnDPz9wOWrxg4A91TVhcA9w7okaRttWOBV9WngW6uGrwZuHpZvBl413ViSpI1sdg787Ko6BjA8nrXejkn2J1lKsrS8vLzJw0mSVtvyFzGr6mBV7a2qvQsLC1t9OEn6sbHZAn8syW6A4fH49CJJksax2QK/E9g3LO8DPjqdOJKkcY1zGeGtwL8Cz0lyNMnrgeuAlyV5BHjZsC5J2kYbvplVVV2zzqbLppxFknQKvBNTkpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpqywCWpKQtckpraMesA41o8cNfMjn3kuqtmdmxJWo9n4JLUlAUuSU1Z4JLUlAUuSU1N9CJmkiPAd4DvAU9U1d5phJIkbWwaV6G8uKoen8LXkSSdAqdQJKmpSc/AC/hkkgL+uqoOrt4hyX5gP8CznvWsCQ83G7O8Bl2S1jPpGfglVfV84ArgjUletHqHqjpYVXurau/CwsKEh5MknTBRgVfVN4bH48CHgYunEUqStLFNF3iSpyd5xoll4OXAoWkFkySd3CRz4GcDH05y4ut8sKr+cSqpJEkb2nSBV9WXgedNMYsk6RR4GaEkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTFrgkNWWBS1JTExV4ksuTPJzki0kOTCuUJGljmy7wJE8B/hK4AngucE2S504rmCTp5CY5A78Y+GJVfbmq/g/4e+Dq6cSSJG1kxwSfew7w9RXrR4FfXr1Tkv3A/mH1v5M8PMExx7ULeHwbjjMpc05fl6zmnK65z5nrf7C4maw/v9bgJAWeNcbqRwaqDgIHJzjOKUuyVFV7t/OYm2HO6euS1ZzT1SUnTDfrJFMoR4HzVqyfC3xjsjiSpHFNUuD/BlyY5PwkTwVeA9w5nViSpI1segqlqp5I8ibgE8BTgJuq6qGpJZvMtk7ZTMCc09clqzmnq0tOmGLWVP3ItLUkqQHvxJSkpixwSWqqdYEnOS/Jp5IcTvJQkjcP42cmuTvJI8PjzllnhdHdq0n+PcnHhvV5zfnMJLcn+cLwvX3hPGZN8nvDv/uhJLcm+cl5yJnkpiTHkxxaMbZuriTXDm9H8XCSX5+DrH8+/Ns/mOTDSZ4566xr5Vyx7Q+SVJJd85ozye8OWR5K8s6p5ayqth/AbuD5w/IzgP9kdFv/O4EDw/gB4PpZZx2y/D7wQeBjw/q85rwZ+J1h+anAM+ctK6Mbyb4C/NSwfhvwW/OQE3gR8Hzg0IqxNXMNP6+fA04Hzge+BDxlxllfDuwYlq+fh6xr5RzGz2N0IcVXgV3zmBN4MfBPwOnD+lnTyrmtP9jb8M37KPAy4GFg9zC2G3h4DrKdC9wDvGRFgc9jzp8eijGrxucqKz+8E/hMRldTfWwonrnICSyu+iVeMxdwLXDtiv0+AbxwlllXbXs1cMs8ZF0rJ3A78DzgyIoCn6ucjE4uXrrGfhPnbD2FslKSReAi4F7g7Ko6BjA8njXDaCf8BfBHwPdXjM1jzl8AloG/HaZ7/ibJ05mzrFX1KPAu4GvAMeC/quqTzFnOFdbLtdZbUpyzzdlO
5reBfxiW5yprklcCj1bV51ZtmqucwLOBX0tyb5J/SfJLw/jEOZ8UBZ7kDOBDwFuq6tuzzrNaklcAx6vqvllnGcMORk8B31tVFwH/w+gp/1wZ5pCvZvTU8+eApyd57WxTbcpYb0kxC0neDjwB3HJiaI3dZpI1ydOAtwN/stbmNcZm+T3dAewEXgD8IXBbkjCFnO0LPMlpjMr7lqq6Yxh+LMnuYftu4Pis8g0uAV6Z5Aijd218SZIPMH85YXQWcLSq7h3Wb2dU6POW9aXAV6pquaq+C9wB/Arzl/OE9XLN5VtSJNkHvAL4zRqe3zNfWS9g9J/354bfq3OB+5P8LPOVE0Z57qiRzzJ6Fr6LKeRsXeDD/2I3Aoer6oYVm+4E9g3L+xjNjc9MVV1bVedW1SKjtxz456p6LXOWE6Cqvgl8PclzhqHLgP9g/rJ+DXhBkqcNPweXAYeZv5wnrJfrTuA1SU5Pcj5wIfDZGeT7gSSXA38MvLKq/nfFprnJWlWfr6qzqmpx+L06yuiChm/OU87BRxi99kWSZzO6MOBxppFzuyb2t+jFgl9l9JTjQeCB4eNK4GcYvWD4yPB45qyzrsh8KT98EXMucwJ7gKXh+/oRRk//5i4r8GfAF4BDwN8xejV/5jmBWxnNy3+XUbG8/mS5GE0FfInRC51XzEHWLzKamz3xO/VXs866Vs5V248wvIg5bzkZFfYHhp/T+4GXTCunt9JLUlOtp1Ak6ceZBS5JTVngktSUBS5JTVngktSUBS5JTVngktTU/wMIVnD4cdC3sgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit(1)\n",
    "# Take element [0][0] of each collected `(xb, yb, pred, loss)` record, i.e. the first input tensor of each batch\n",
    "t = concat(*learn.collect_data.data.itemgot(0,0))\n",
    "# Histogram of the inputs that were drawn; the trailing `;` suppresses the cell's text output\n",
    "plt.hist(t.numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates()\n",
    "class PartialDL(TfmdDL):\n",
    "    \"Dataloader that randomly selects `partial_n` items (without replacement) at each epoch\"\n",
    "    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):\n",
    "        super().__init__(dataset=dataset, bs=bs, **kwargs)\n",
    "        # Never ask for more than `self.n` items; a falsy `partial_n` (None or 0) disables subsetting\n",
    "        self.partial_n = min(partial_n, self.n) if partial_n else None\n",
    "\n",
    "    def get_idxs(self):\n",
    "        if self.partial_n is None: return super().get_idxs()\n",
    "        return list(np.random.choice(self.n, self.partial_n, replace=False))\n",
    "\n",
    "    def __len__(self):\n",
    "        if self.partial_n is None: return super().__len__()\n",
    "        # Without a batch size the integer division below would raise a `TypeError`, so report the\n",
    "        # number of sampled items instead. NOTE(review): assumed to match how the parent dataloader\n",
    "        # reports its length when `bs is None` - confirm\n",
    "        if self.bs is None: return self.partial_n\n",
    "        n_batches,remainder = divmod(self.partial_n, self.bs)\n",
    "        # A trailing incomplete batch is counted unless `drop_last` is set\n",
    "        if remainder and not self.drop_last: n_batches += 1\n",
    "        return n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@patch\n",
    "@delegates(Datasets.dataloaders)\n",
    "def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):\n",
    "    \"Create dataloaders of type `PartialDL`, with `partial_n` applied to the training set\"\n",
    "    # One kwargs dict per subset: only the first entry carries `partial_n`\n",
    "    per_subset_kwargs = ({'partial_n':partial_n},) + ({},)*(self.n_subsets-1)\n",
    "    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=per_subset_kwargs, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses `dsets` from the weighted example above: 32 items per epoch, in batches of 16\n",
    "dls = dsets.partial_dataloaders(partial_n=32, bs=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 32 sampled items at batch size 16 should give exactly two batches, each of them full\n",
    "assert len(dls[0])==2\n",
    "for batch in dls[0]:\n",
    "    assert len(batch[0])==16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
